
Conversation

@lessw2020 (Contributor) commented Sep 30, 2024

This PR:
1 - updates batch_decode_next_tokens so that it handles both llama2 and llama3 with their respective tokenizers, giving a single universal decoding path for in-flight decoding.

2 - adds a temperature option to enable non-deterministic (creative) decoding, using top-k and multinomial selection (see the sketch after this list).

3 - updates decode_in_flight, again to be compatible with both llama2 and llama3.

4 - minor tweaks: uses zip for the final display, and moves decode_in_flight to _decode_in_flight alongside the other module-level functions for ease of reference.
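
For reference, a minimal sketch of the sampling path described in points 1-2, written as a standalone function. The name sample_next_tokens, the deterministic temperature == 0.0 branch, and the default topk value are illustrative assumptions here, not the exact _batch_decode_next_tokens signature in this PR.

import torch

def sample_next_tokens(
    logits: torch.Tensor,   # (batch_size, vocab_size): next-token logits per prompt
    temperature: float = 1.0,
    topk: int = 10,
) -> torch.Tensor:
    """Return one next-token id per prompt in the batch."""
    if temperature == 0.0:
        # Deterministic decoding: plain argmax over the vocabulary (assumed convention).
        return torch.argmax(logits, dim=-1)

    # Higher temperature flattens the distribution, lower temperature sharpens it.
    logits = logits / temperature

    # Restrict sampling to the k most likely tokens (k capped at vocab size).
    top_k = min(topk, logits.size(-1))
    top_k_logits, top_k_indices = torch.topk(logits, k=top_k, dim=-1)

    # Renormalize over the retained logits and draw one sample per prompt.
    probs = torch.softmax(top_k_logits, dim=-1)
    sampled = torch.multinomial(probs, num_samples=1)        # (batch_size, 1)

    # Map positions inside the top-k slice back to vocabulary ids.
    return top_k_indices.gather(-1, sampled).squeeze(-1)     # (batch_size,)

Since both sentencepiece (llama2) and tiktoken (llama3) decode plain lists of integer token ids, the same sampled ids can be handed to either tokenizer, which is what makes a single decoding path workable.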

Tested with both llama2 and llama3.
Example multi-prompt output with llama2:

Prompt: What is Snow? 
09-30 12:23:45.151 - dist_run:548 - Response: 
Snow is a form of frozen water that falls from the sky during winter months. It is created when water vapor in the air freezes into ice crystals, which then 
09-30 12:23:45.151 - dist_run:547 - Prompt: Who is Santa Claus? 
09-30 12:23:45.151 - dist_run:548 - Response: 
Santa Claus is a legendary figure associated with Christmas and the tradition of gift-giving during the holiday season. The character is based 

@pytorch-bot bot commented Sep 30, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1234

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 1e2eec0 with merge base 7ad9ba2:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Sep 30, 2024
@metascroy (Contributor) commented:

Rebase to fix the failing torchao_experimental check.

@lessw2020 lessw2020 changed the title Implement universal batch_decode & decode_in_flight for llama2 & llama3, with deterministic or multinomial (topk) decoding (handle both sentencepiece (llama2) and tiktoken (llama3)) [Distributed] Implement universal batch_decode & decode_in_flight for llama2 & llama3, with deterministic or multinomial (topk) decoding (handle both sentencepiece (llama2) and tiktoken (llama3)) Sep 30, 2024
@lessw2020 lessw2020 requested a review from kwen2501 September 30, 2024 20:09
@kwen2501 (Contributor) left a comment

Thanks for the new feature! LGTM. Just some minor comments.

) -> torch.Tensor:
"""
-    Decode the next token for each prompt in the batch.
+    Decode the next token for each prompt in the batch. Adds temperature option for non-deterministic decoding.

I wonder if torchchat's generate also has the temperature option. Shall we think about how to connect this with generate in a next step?

Comment on lines +217 to +218
+    if temperature != 1.0:
+        next_token_logits = next_token_logits / temperature

nit: can we do the division unconditionally?
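
A minimal sketch of the suggested unconditional form (assuming temperature is always a positive float, which the thread does not state):

# Dividing unconditionally is equivalent: when temperature == 1.0 the division
# is a no-op, so the branch adds nothing (assumes temperature > 0).
next_token_logits = next_token_logits / temperature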

return next_token
batch_size, seq_len, vocab_size = output.shape

if step != -1:
@kwen2501 (Contributor) commented Oct 1, 2024

nit: can step == -1 be represented by pos = [] or pos = [0, 0, ...]? (saving one argument)

Comment on lines +222 to +228
top_k = min(topk, vocab_size) # Ensure top-k is not greater than vocab size
top_k_logits, top_k_indices = torch.topk(next_token_logits, k=top_k, dim=-1)
probs = torch.softmax(top_k_logits, dim=-1)
next_token_indices = torch.multinomial(probs, num_samples=1).squeeze(-1)
next_tokens = top_k_indices.gather(
-1, next_token_indices.unsqueeze(-1)
).squeeze(-1)

nit: do you mind adding more comments here for the multinomial, gather, squeeze and unsqueeze ops?
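
An annotated restatement of the quoted lines; the comments are added here for illustration and are not part of the PR diff.

# Cap k at the vocabulary size, then keep the k highest-scoring logits per prompt.
top_k = min(topk, vocab_size)
top_k_logits, top_k_indices = torch.topk(next_token_logits, k=top_k, dim=-1)

# Softmax over the retained logits gives a proper distribution of shape
# (batch_size, top_k).
probs = torch.softmax(top_k_logits, dim=-1)

# multinomial draws one sample per row -> (batch_size, 1); squeeze(-1) drops
# the sample dimension -> (batch_size,). These are positions within the
# top-k slice, not vocabulary ids.
next_token_indices = torch.multinomial(probs, num_samples=1).squeeze(-1)

# unsqueeze(-1) restores the column that gather expects, gather maps the
# slice positions back to vocabulary ids, and the final squeeze(-1) leaves
# one token id per prompt -> (batch_size,).
next_tokens = top_k_indices.gather(
    -1, next_token_indices.unsqueeze(-1)
).squeeze(-1)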

# New token generated each iteration
# need a row dimension for each prompt in the batch
new_token = torch.zeros(batch_size, 1, device=device, dtype=torch.int64)
logger.info(f"{color.green}{new_token.shape=}, {new_token=}{color.reset}")

nit: for debugging only?

Comment on lines 463 to +464
 if not args.disable_in_flight_decode:
-    decode_in_flight(new_token)
+    _decode_in_flight(new_token, tokenizer, tp_rank)

nit: put tp_rank to the if condition?
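
One possible reading of this suggestion, sketched under the assumption that tp_rank only gates whether this rank decodes; the changed helper signature is hypothetical.

# Hypothetical variant: gate on the rank at the call site so the helper no
# longer needs the tp_rank argument (assumes only tp rank 0 should decode).
if not args.disable_in_flight_decode and tp_rank == 0:
    _decode_in_flight(new_token, tokenizer)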

# Decode the output
if pp_rank == last_pp_rank:
    new_token = _batch_decode_next_tokens(output, 0)
    # logger.info(f"{color.red}Decoding...{output.shape=}{color.reset}")

nit: remove log?

Comment on lines +536 to +537
# Decode the output as comprehension instead of loop
responses = [tokenizer.decode(sequence) for sequence in res_list]

Hmm, did the previous code not work in case of variable length? Just curious.
(previous code: response = tokenizer.decode(res_list))
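
A possible reading of the change, not an answer from the thread: if res_list holds one token list per prompt and the lists have different lengths, the comprehension decodes each prompt on its own. A self-contained illustration with a stand-in tokenizer (the real object is sentencepiece or tiktoken, and the ids are made up):

# Stand-in tokenizer for illustration only.
class DemoTokenizer:
    def decode(self, ids):
        return " ".join(str(i) for i in ids)

tokenizer = DemoTokenizer()

# Hypothetical per-prompt token ids of different lengths.
res_list = [[101, 2023, 2003], [101, 2129, 2024, 2017, 102]]

# Per-prompt decode keeps variable-length sequences separate.
responses = [tokenizer.decode(sequence) for sequence in res_list]
print(responses)  # one decoded string per prompt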

    next_token_logits = output[:, 0, :]
else:
    # get the logits for each prompt at the specified positions
    next_token_logits = output[torch.arange(batch_size), torch.tensor(pos) - 1]

Hmm, why the "-1"?
From this function's perspective, if the caller has given the position, should it just faithfully decode that position?
(I understand that this runs correctly if the callsite provides prompt_length instead of prompt_length - 1.)
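
A sketch of one rationale for the offset (an assumption, not the author's reply): with output of shape (batch_size, seq_len, vocab_size) from a decoder-only model, the logits at sequence index i predict the token at index i + 1, so the first token generated after a prompt of length pos[b] comes from index pos[b] - 1.

import torch

batch_size, seq_len, vocab_size = 2, 8, 32
output = torch.randn(batch_size, seq_len, vocab_size)   # stand-in model logits
pos = [5, 7]                                             # hypothetical prompt lengths per row

# output[b, i, :] scores the token that follows position i, so the next token
# for prompt b is predicted by the logits at index pos[b] - 1.
next_token_logits = output[torch.arange(batch_size), torch.tensor(pos) - 1]
print(next_token_logits.shape)                           # torch.Size([2, 32])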

@lessw2020 lessw2020 merged commit 77bac00 into main Oct 2, 2024
52 checks passed
@lessw2020 (Contributor, Author) commented:

Will revisit to address the review comments; landing for now so we can work on the TP fix.

